import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import kneighbors_graph
import cvxpy as cp

class CVX(object):
    """
    Reference solver used to verify the correctness of the ADMM that produces 'Z'.

    ``execute`` solves, with cvxpy,

        min_Y  ||D (Y - Ytn)||_1 + ||Y - (Ytn - eta * gt)||_2^2 / (2 * eta * gamma)

    where D is the difference matrix passed as ``givenDiffMat``.
    """
    def __init__(self, givenYtn, givenGt, givenEta, givenGamma, givenNumberOfFeatures, givenDiffMat=None):
        """
        givenYtn: the point Ytn, length givenNumberOfFeatures.
        givenGt: the vector gt, length givenNumberOfFeatures.
        givenEta, givenGamma: scalars; eta scales gt and both appear in the 2*eta*gamma denominator.
        givenNumberOfFeatures: d, the length of the optimization variable Y.
        givenDiffMat: difference matrix D; it needs d columns so that D @ (Y - Ytn) is defined.

        Raises TypeError when givenDiffMat is not supplied. QMatrix is built from a
        GraphGenerater (QMatrix(givenGraph=...)), so it cannot be derived from the feature
        count alone; the former QMatrix(givenNumberOfFeatures=...) call always failed.
        """
        if givenDiffMat is None:
            raise TypeError("CVX requires 'givenDiffMat' (e.g. QMatrix(givenGraph=...).get()); "
                            "it cannot be built from givenNumberOfFeatures alone")
        self.__diffMat = np.asarray(givenDiffMat)
        self.__d = givenNumberOfFeatures
        self.__gt = givenGt
        self.__eta = givenEta
        self.__gamma = givenGamma
        self.__Ytn = givenYtn
        # placeholder solution until execute() is called
        self.__V = np.zeros(self.__d)

    def execute(self):
        """Solve the problem above and store Y's value as the solution."""
        Y = cp.Variable(self.__d)
        smoothTarget = self.__Ytn - self.__eta * self.__gt
        objective = cp.norm1(self.__diffMat @ (Y - self.__Ytn)) + (cp.norm2(Y - smoothTarget) ** 2) / (2 * self.__eta * self.__gamma)
        problem = cp.Problem(cp.Minimize(objective))
        problem.solve()
        self.__V = Y.value

    def getV(self):
        """Return the stored solution (zeros before execute())."""
        return self.__V

class AuxiliaryMatrix(object):
    """
    Selector matrix (M or N) whose columns are standard basis vectors.

    Column j equals e_{featureIdList[j]}, i.e. it has a single 1 in the row of the
    j-th requested feature. For featureIdList = [0, 1, 3] and 6 features, M is:
    1 0 0
    0 1 0
    0 0 0
    0 0 1
    0 0 0
    0 0 0
    and for featureIdList = [2, 4, 5], N is:
    0 0 0
    0 0 0
    1 0 0
    0 0 0
    0 1 0
    0 0 1
    """
    def __init__(self, givenFeatureIdList, givenNumOfFeatures):
        """
        givenFeatureIdList: feature ids to select, e.g. [0, 1, 3].
        givenNumOfFeatures: total number of features (number of rows).
        """
        self.__featureIdList = givenFeatureIdList
        self.__numOfRows = givenNumOfFeatures
        self.__numOfCols = len(givenFeatureIdList)
        self.__auxiliaryMat = self.__generate()

    def __generate(self):
        # Picking columns of the identity by id yields exactly the selector matrix.
        identity = np.eye(self.__numOfRows)
        return identity[:, list(self.__featureIdList)]

    def get(self):
        """Return the numOfFeatures by len(featureIdList) selector matrix."""
        return self.__auxiliaryMat

class MMatrix(object):
    """
    Selector matrix M for the sharing component, built with AuxiliaryMatrix.
    For sharing feature ids [0, 1, 3] out of 6 features, M looks like:
    1 0 0
    0 1 0
    0 0 0
    0 0 1
    0 0 0
    0 0 0
    """
    def __init__(self, givenFeatureIdListOfSharingComponent, givenNumOfFeatures):
        """
        givenFeatureIdListOfSharingComponent: feature ids of the sharing component, e.g. [0, 1, 3].
        givenNumOfFeatures: total number of features (rows of M).
        """
        auxiliary = AuxiliaryMatrix(givenFeatureIdList=givenFeatureIdListOfSharingComponent,
                                    givenNumOfFeatures=givenNumOfFeatures)
        self.__M = auxiliary.get()

    def get(self):
        """Return M."""
        return self.__M

class NMatrix(object):
    """
    Selector matrix N for the personalized component, built with AuxiliaryMatrix.
    For personalized feature ids [2, 4, 5] out of 6 features, N looks like:
    0 0 0
    0 0 0
    1 0 0
    0 0 0
    0 1 0
    0 0 1
    """
    def __init__(self, givenFeatureIdListOfPersonalizedComponent, givenNumOfFeatures):
        """
        givenFeatureIdListOfPersonalizedComponent: feature ids of the personalized component, e.g. [2, 4, 5].
        givenNumOfFeatures: total number of features (rows of N).
        """
        auxiliary = AuxiliaryMatrix(givenFeatureIdList=givenFeatureIdListOfPersonalizedComponent,
                                    givenNumOfFeatures=givenNumOfFeatures)
        self.__N = auxiliary.get()

    def get(self):
        """Return N."""
        return self.__N

class QQMatrix(object):
    """
    Symmetric positive semi-definite Gram matrix of Q: 'Q^T Q' or 'Q Q^T'.

    The eigen-decomposition mat = P diag(sigma) P^T is precomputed. Because mat is
    symmetric, np.linalg.eigh is used: it returns real eigenvalues (ascending) and an
    orthonormal P, so P^{-1} = P^T holds even when eigenvalues repeat (np.linalg.eig
    does not guarantee that).
    """
    def __init__(self, givenQMatrix, givenTransposeOperatorLoc):
        """
        givenQMatrix: the matrix Q.
        givenTransposeOperatorLoc: 'left' for Q^T Q, 'right' for Q Q^T.

        Raises ValueError for any other givenTransposeOperatorLoc.
        """
        if givenTransposeOperatorLoc == 'left':
            self.__mat = givenQMatrix.T @ givenQMatrix
        elif givenTransposeOperatorLoc == 'right':
            self.__mat = givenQMatrix @ givenQMatrix.T
        else:
            raise ValueError("givenTransposeOperatorLoc must be 'left' or 'right', got %r" % (givenTransposeOperatorLoc,))
        # mat = P * diag(sigma) * P^T with orthonormal P
        self.__sigma, self.__P = np.linalg.eigh(self.__mat)

    def get(self):
        """Return the Gram matrix."""
        return self.__mat

    def getEigenValue(self):
        """Return the real eigenvalues (ascending order)."""
        return self.__sigma

    def getEigenVec(self):
        """Return P; column k is the eigenvector for getEigenValue()[k]."""
        return self.__P

class QMatrix(object):
    """
    The difference matrix 'Q' (N nodes by M edges) generated from graph G.

    Each entry (i, j) of the adjacency matrix equal to 1, visited in row-major order,
    contributes one column of Q with +1 in row i and -1 in row j.
    """
    def __init__(self, givenGraph):
        assert(isinstance(givenGraph, GraphGenerater))
        self.__graph = givenGraph
        self.__diffMat = self.__initDiffMat()

    def __initDiffMat(self):
        """
        Build Q: N nodes by M edges. A graph without edges gives an N by 0 matrix.
        """
        adjacentMat = self.__graph.getAdjacentMat()
        # The adjacency may be a scipy sparse matrix; element-wise indexing on it is slow,
        # so work on a dense copy.
        if hasattr(adjacentMat, 'toarray'):
            adjacentMat = adjacentMat.toarray()
        adjacentMat = np.asarray(adjacentMat)
        numOfNodes = adjacentMat.shape[0]
        # np.nonzero scans in row-major order, matching the former i/j double loop.
        rowIds, colIds = np.nonzero(adjacentMat == 1)
        edgeIds = np.arange(len(rowIds))
        diffMat = np.zeros((numOfNodes, len(rowIds)))
        diffMat[rowIds, edgeIds] = 1
        # Written second, so a self-loop (i == j) ends up -1, as in the former loop.
        diffMat[colIds, edgeIds] = -1
        return diffMat

    def get(self):
        """Return Q (N by M)."""
        return self.__diffMat

class SketchMatrixForSimilarity(object):
    """
    Similarity space made of per-node data sketches.
    """
    def __init__(self, givenSketchList):
        """
        givenSketchList: N by ?, row n is the data sketch of node n. Stored as given (not copied).
        """
        self.__sketchMat = givenSketchList

    def get(self):
        """Return the sketch list exactly as it was passed in."""
        return self.__sketchMat

class CovarianceMatrixForSimilarity(object):
    """
    Similarity space intended to hold per-node covariances (N by ?).

    NOTE: not implemented yet -- __init returns None, so get() returns None.
    """
    def __init__(self, givenLocalDataList):
        """
        givenLocalDataList: per-node local data, stored for the future covariance computation.
        """
        self.__localDataList = givenLocalDataList
        self.__covarianceMat = self.__init()

    def __init(self):
        # TODO: compute the covariance matrix from self.__localDataList.
        return None

    def get(self):
        """Return the covariance matrix (currently None)."""
        return self.__covarianceMat

class ModelMatrixForSimilarity(object):
    """
    Similarity space intended to hold per-node local models (modelMat: N by ?).

    NOTE: not implemented yet -- __init returns None, so get() returns None.
    """
    def __init__(self, givenLocalDataList):
        """
        givenLocalDataList: per-node local data, stored for the future model computation.
        """
        self.__localDataList = givenLocalDataList
        self.__modelMat = self.__init()

    def __init(self):
        # TODO: fit the local models from self.__localDataList.
        return None

    def get(self):
        """Return the model matrix (currently None)."""
        return self.__modelMat

class GraphGenerater(object):
    """
    k-nearest-neighbour connectivity graph over the rows of a similarity space.
    """
    def __init__(self, givenNumOfNode, givenSimilaritySpace, givenNumOfNeighbors=2):
        """
        givenNumOfNode: number of nodes N (stored; the graph size comes from the similarity space).
        givenSimilaritySpace: object whose get() returns one row per node.
        givenNumOfNeighbors: k of the kNN graph (previously hard-coded to 2).
        """
        self.__similaritySpace = givenSimilaritySpace
        self.__N = givenNumOfNode
        self.__numOfNeighbors = givenNumOfNeighbors
        self.__adjacentMat = self.generateAdjacentMat()

    def generateAdjacentMat(self):
        """
        Build the kNN adjacency with sklearn's kneighbors_graph in 'connectivity' mode
        (per the sklearn docs this is a scipy sparse matrix).
        """
        adjacentMat = kneighbors_graph(self.__similaritySpace.get(), self.__numOfNeighbors, mode='connectivity')
        return adjacentMat

    def getAdjacentMat(self):
        """Return the adjacency matrix built at construction time."""
        return self.__adjacentMat

class ADMM(object):
    """
    Optimize Z with ADMM (scaled-dual form) for a fixed number of iterations.

    With the splitting W = Z Q and G = N^T Nabla / numOfNodes, one iteration is
        Z     <- (eta * ((Omega + W) Q^T - G) + Z) (I + eta Q Q^T)^{-1}
        W_m   <- max(1 - lambda / ||r_m||, 0) * r_m,   r_m = Z q_m - Omega_m
        Omega <- Omega + W - Z Q
    """
    def __init__(self, givenQ, givenNabla, givenNMat, givenEta, givenLambda, givenNumOfIterations=25):
        """
        givenQ: difference matrix, numOfNodes by numOfEdges.
        givenNabla: d by numOfNodes gradients.
        givenNMat: d by d2 selector of the personalized component.
        givenEta, givenLambda: step size and group-penalty weight.
        givenNumOfIterations: ADMM iterations (previously hard-coded to 25).
        """
        self.__Q = givenQ
        self.__N, self.__M = self.__Q.shape
        self.__Nabla = givenNabla # d by N
        self.__NMat = givenNMat # d by d2
        _, d2 = self.__NMat.shape
        self.__eta = givenEta
        self.__lambda = givenLambda
        # Q Q^T is symmetric PSD; eigh gives real eigenvalues and an orthonormal P, so
        # P diag(1 / (1 + eta * sigma)) P^T is exactly (I + eta Q Q^T)^{-1}.
        self.__sigma, self.__P = np.linalg.eigh(self.__Q @ self.__Q.T)
        # Both quantities below are constant across iterations; compute them once.
        self.__inverse = self.__P @ np.diag(np.true_divide(1.0, 1 + self.__eta * self.__sigma)) @ self.__P.T
        self.__personalizedGrad = (self.__NMat.T @ self.__Nabla) / self.__N # d2 by N
        self.__Z, self.__W, self.__Omega = np.zeros((d2, self.__N)), np.zeros((d2, self.__M)), np.zeros((d2, self.__M))
        self.__K = givenNumOfIterations
        self.execute()

    def __updateZ(self):
        left = self.__eta * ((self.__Omega + self.__W) @ self.__Q.T - self.__personalizedGrad) + self.__Z
        self.__Z = left @ self.__inverse

    def __updateW(self):
        # Column m of the residual is r_m = Z q_m - Omega_m.
        residual = self.__Z @ self.__Q - self.__Omega
        norms = np.linalg.norm(residual, axis=0)
        # Block soft-thresholding. A zero residual maps to zero; dividing instead would give
        # 0/0 = NaN when lambda == 0 and poison every later iterate.
        scale = np.zeros_like(norms)
        positive = norms > 0
        scale[positive] = np.maximum(1 - self.__lambda / norms[positive], 0)
        self.__W = residual * scale

    def __updateOmega(self):
        self.__Omega = self.__Omega + (self.__W - self.__Z @ self.__Q)

    def execute(self):
        """Run the ADMM iterations (called from the constructor)."""
        for _ in range(self.__K):
            self.__updateZ()
            self.__updateW()
            self.__updateOmega()

    def getZ(self):
        """Return Z (d2 by numOfNodes)."""
        return self.__Z

class FormulationSettings(object):
    """
    Problem dimensions, the selector matrices M and N, and the penalty weight lambda.
    """
    def __init__(self, givenFeatureIdListOfSharingComponent, givenFeatureIdListOfPersonalizedComponent, givenNumberOfSamples, givenLambda):
        """
        givenFeatureIdListOfSharingComponent: feature ids of the sharing component (d1 of them).
        givenFeatureIdListOfPersonalizedComponent: feature ids of the personalized component (d2 of them).
        givenNumberOfSamples: stored and returned by getNumberOfSamples().
        givenLambda: penalty weight returned by getLambda().
        """
        self.__d1 = len(givenFeatureIdListOfSharingComponent)
        self.__d2 = len(givenFeatureIdListOfPersonalizedComponent)
        # NOTE(review): d = d1 + d2 presumes the two id lists partition 0..d-1 -- confirm with callers.
        self.__d = self.__d1 + self.__d2
        self.__N = givenNumberOfSamples
        self.__MMat = MMatrix(givenFeatureIdListOfSharingComponent=givenFeatureIdListOfSharingComponent,
                              givenNumOfFeatures=self.__d).get()
        self.__NMat = NMatrix(givenFeatureIdListOfPersonalizedComponent=givenFeatureIdListOfPersonalizedComponent,
                              givenNumOfFeatures=self.__d).get()
        self.__lambda = givenLambda

    def getNumberOfSamples(self):
        return self.__N

    def getNumberOfFeatures(self):
        """Return d = d1 + d2."""
        return self.__d

    def getNumberOfSharingFeatures(self):
        return self.__d1

    def getNumberOfPersonalizedFeatures(self):
        return self.__d2

    def getMMat(self):
        """Return M (d by d1)."""
        return self.__MMat

    def getNMat(self):
        """Return N (d by d2)."""
        return self.__NMat

    def getLambda(self):
        return self.__lambda

class AlternativeOptimizer(object):
    """
    One alternating update of the split model, run at construction time.

    First a gradient step on the sharing component x, then an ADMM solve for the
    personalized component Z; finally Xn = M x 1^T + N Z is assembled.
    """
    def __init__(self, givenInitX, givenInitZ, givenQ, givenFormulationSettings, givenEta, givenNabla):
        assert(isinstance(givenFormulationSettings, FormulationSettings))
        # d by N; column n is node n's local update, e.g. [g_t^1, g_t^2, g_t^3, g_t^4]
        self.__Nabla = givenNabla
        self.__d, self.__N = self.__Nabla.shape
        self.__MMat = givenFormulationSettings.getMMat()
        self.__NMat = givenFormulationSettings.getNMat()
        self.__Q = givenQ
        self.__x = givenInitX # sharing component, length d1
        self.__Z = givenInitZ # d2 by N
        self.__eta = givenEta
        self.__lambda = givenFormulationSettings.getLambda()
        self.step()
        sharedPart = self.__MMat @ self.__x
        self.__Xn = np.einsum('i,j->ij', sharedPart, np.ones(self.__N)) + self.__NMat @ self.__Z # d by N

    def step(self):
        """Update x, then Z."""
        self.__updateX()
        self.__updateZ()

    def __updateX(self):
        # x <- x - eta * (mean over nodes of M^T Nabla)
        sharedGrad = self.__MMat.T @ self.__Nabla # d1 by N
        self.__x = self.__x - self.__eta * sharedGrad @ np.ones(self.__N) / self.__N

    def __updateZ(self):
        solver = ADMM(givenQ=self.__Q, givenNabla=self.__Nabla, givenNMat=self.__NMat,
                      givenEta=self.__eta, givenLambda=self.__lambda)
        self.__Z = solver.getZ()

    def getX(self):
        """
        sharing component: length d1
        """
        return self.__x

    def getZ(self):
        """Return Z (d2 by N)."""
        return self.__Z

    def getXn(self):
        """
        per-node models M x 1^T + N Z: d by N
        """
        return self.__Xn

class ClusterPath(object):
    """
    Plot how 2-D client coordinates move from one phase to the next.
    """
    def __init__(self, givenPointCollections = None, givenSavedFigPath = ''):
        """
        givenPointCollections: numberOfPhase by numberOfClient by numberOfFeature; only the
        first two features are plotted. Defaults to a 3-phase, 2-client example:
        [
          [
            [1,2],
            [-1,-2] # client
          ], # phase
        ]
        givenSavedFigPath: file to save the figure to; an empty string skips saving.
        """
        if givenPointCollections is None:
            # Built per call to avoid a shared mutable default argument.
            givenPointCollections = [[[1,2],[-1,-2]], [[3,4],[-3,-4]], [[4,-5],[-4,5]]]
        self.__aux = self.__initHandler()
        # np.asarray so plain nested lists (like the default) support [:, 0] slicing.
        self.__pointCollectionsList = [np.asarray(pointList) for pointList in givenPointCollections]
        self.__savedFigPath = givenSavedFigPath

    def __initHandler(self):
        fig = plt.figure()
        ax1 = fig.add_subplot(1, 1, 1)
        return ax1

    def __plotLines(self, givenStartPointList, givenEndPointList, givenLineType = 'r-'):
        """
        givenStartPointList: [[x-coordinate, y-coordinate]]
        givenEndPointList: [[x-coordinate, y-coordinate]]
        Draws one segment from each start point to its matching end point.
        """
        for startPoint, endPoint in zip(givenStartPointList, givenEndPointList):
            self.__aux.plot([startPoint[0], endPoint[0]], [startPoint[1], endPoint[1]], givenLineType, linewidth = 1)

    def __plotPoints(self, givenPointList):
        self.__aux.scatter(givenPointList[:,0], givenPointList[:,1], marker='o', facecolors='None', edgecolors='red')

    def plotFullPath(self):
        """
        Mark the first phase's points, connect each phase to the next, then save (when a
        path was given) and show the figure.
        """
        pointCollections = self.__pointCollectionsList
        if len(pointCollections) > 0:
            self.__plotPoints(givenPointList=pointCollections[0])
        for startPoints, endPoints in zip(pointCollections[:-1], pointCollections[1:]):
            self.__plotLines(givenStartPointList=startPoints, givenEndPointList=endPoints, givenLineType = '+r-')
        self.__aux.set_xlabel('First Principal Component')
        self.__aux.set_ylabel('Second Principal Component')
        self.__aux.legend([])
        if self.__savedFigPath:
            # Save this plot's own figure, not whatever figure pyplot considers current.
            self.__aux.figure.savefig(self.__savedFigPath)
        plt.show()

class ToyDataOfConvexCluster(object):
    """
    Toy dataset of standard-normal samples drawn from NumPy's global RNG.
    """
    def __init__(self, givenNumberOfSamples=100, givenNumberOfFeatures=6):
        """
        dataMat: numberOfSamples by numberOfFeatures
        """
        self.__dataMat = np.random.randn(givenNumberOfSamples, givenNumberOfFeatures)
        self.__numOfSamples, self.__numOfFeatures = self.__dataMat.shape

    def getDataMat(self):
        return self.__dataMat

    def extractData(self, givenSampleIds):
        """Return the rows for givenSampleIds stacked into an array (shape (0,) when empty)."""
        rows = [self.__dataMat[sampleId] for sampleId in givenSampleIds]
        return np.array(rows)

    def querySample(self, givenSampleId):
        """Return a single sample (row) of the data matrix."""
        return self.__dataMat[givenSampleId]

    def getNumberOfSamples(self):
        return self.__numOfSamples

    def getNumberOfFeatures(self):
        return self.__numOfFeatures

class DataFederation(object):
    """
    Splits a dataset across federated clients and builds the client similarity graph.
    """
    def __init__(self, givenDatasetObj, givenFederatedLearningSettings):
        assert(isinstance(givenFederatedLearningSettings, FederatedLearningSettings))
        self.__datasetObj = givenDatasetObj # numberOfSamples by numberOfFeatures
        self.__federatedLearningSettings = givenFederatedLearningSettings

    def generateSampleIdListForClients(self):
        """
        Partition sample ids 0..S-1 into N contiguous ranges, one per client.

        Every client gets floor(S / N) consecutive ids and the last client also takes the
        S mod N leftovers, so no sample is dropped. (Previously the leftovers were silently
        discarded.) If N > S, all clients but the last receive an empty range.

        Returns: list of N ``range`` objects.
        Raises: ValueError when the number of clients is below 1.
        """
        numOfClients = self.__federatedLearningSettings.getNumOfClients()
        if numOfClients < 1:
            raise ValueError("number of clients must be >= 1, got %r" % (numOfClients,))
        numOfSamples = self.getNumberOfSamples()
        samplesPerClient = numOfSamples // numOfClients
        sampleIdsOfClientList = []
        for clienti in range(numOfClients):
            startidForClient = clienti * samplesPerClient
            isLastClient = clienti == numOfClients - 1
            endidForClient = numOfSamples if isLastClient else startidForClient + samplesPerClient
            sampleIdsOfClientList.append(range(startidForClient, endidForClient))
        return sampleIdsOfClientList

    def generateQMat(self):
        """
        Build Q (N by M) from the kNN graph over per-client sketches (the mean sample of
        each client).
        """
        numOfClients = self.__federatedLearningSettings.getNumOfClients()
        sketchList = [np.mean(self.__datasetObj.extractData(givenSampleIds=sampleIdsOfClient), axis=0)
                      for sampleIdsOfClient in self.generateSampleIdListForClients()]
        similaritySpace = SketchMatrixForSimilarity(givenSketchList=sketchList)
        graph = GraphGenerater(givenNumOfNode=numOfClients, givenSimilaritySpace=similaritySpace)
        return QMatrix(givenGraph=graph).get() # N by M

    def getDatasetObj(self):
        return self.__datasetObj

    def querySample(self, givenSampleId):
        return self.__datasetObj.querySample(givenSampleId)

    def getNumberOfSamples(self):
        return self.__datasetObj.getNumberOfSamples()

    def getNumberOfFeatures(self):
        return self.__datasetObj.getNumberOfFeatures()


class FederatedLearningSettings(object):
    """
    Number of clients and number of training rounds.
    """
    def __init__(self, givenNumberOfClients=5, givenNumberOfIterations=100):
        self.__numOfClients, self.__numOfIterations = givenNumberOfClients, givenNumberOfIterations

    def getNumOfClients(self):
        return self.__numOfClients

    def getNumOfIterations(self):
        return self.__numOfIterations

class PersonalizedModelSettings(object):
    """
    Which feature ids form the sharing component and which form the personalized one.
    """
    def __init__(self, givenFeatureIdListOfSharingComponent, givenFeatureIdListOfPersonalizedComponent):
        self.__featureIdListOfSharingComponent, self.__featureIdListOfPersonalizedComponent = (
            givenFeatureIdListOfSharingComponent, givenFeatureIdListOfPersonalizedComponent)

    def getFeatureIdListOfSharingComponent(self):
        return self.__featureIdListOfSharingComponent

    def getFeatureIdListOfPersonalizedComponent(self):
        return self.__featureIdListOfPersonalizedComponent

class PersonalizedConvexCluster(object):
    """
    Personalized convex clustering trained over federated rounds.

    Client n's model is column n of M x 1^T + N Z (numberOfFeatures by numberOfClients),
    where x is the sharing component and Z the personalized one. Each round computes
    per-client gradients and lets AlternativeOptimizer update (x, Z). Training runs in
    the constructor.
    """
    def __init__(self, givenDataFederation, givenPersonalizedModelSettings, givenFormulationSettings, givenFederatedLearningSettings):
        assert(isinstance(givenDataFederation, DataFederation))
        assert(isinstance(givenPersonalizedModelSettings, PersonalizedModelSettings))
        assert(isinstance(givenFormulationSettings, FormulationSettings))
        assert(isinstance(givenFederatedLearningSettings, FederatedLearningSettings))
        self.__dataFederation = givenDataFederation
        self.__personalizedModelSettings = givenPersonalizedModelSettings
        self.__formulationSettings = givenFormulationSettings
        self.__federatedLearningSettings = givenFederatedLearningSettings
        self.__personalizedModelVecs = np.zeros((self.__formulationSettings.getNumberOfFeatures(), self.__federatedLearningSettings.getNumOfClients())) # numberOfFeatures by numberOfClients
        self.executeByFedAvG(givenNumOfIterations=self.__federatedLearningSettings.getNumOfIterations())

    def computeStochasticGradient(self, givenModelVec, givenSample):
        """
        Per-sample gradient (givenModelVec - givenSample) / S, with S the total number of
        samples in the dataset.
        """
        stochasticGrad = 1/self.__dataFederation.getDatasetObj().getNumberOfSamples() * (givenModelVec - givenSample)
        return stochasticGrad

    def computeAverageOfGradients(self, givenModelVec, givenSampleIdList):
        """
        Mean of computeStochasticGradient over givenSampleIdList (length numberOfFeatures).
        An empty id list divides 0 by 0 and yields NaN entries.
        """
        aveGrad = np.zeros(self.__dataFederation.getNumberOfFeatures())
        for sampleId in givenSampleIdList:
            aveGrad = aveGrad + self.computeStochasticGradient(givenModelVec=givenModelVec, givenSample=self.__dataFederation.querySample(sampleId))
        return np.true_divide(aveGrad, len(givenSampleIdList))

    def computePersonalizedGradients(self, givenPersonalizedModelVecs):
        """
        givenPersonalizedModelVecs: d by N
        Returns: N by d; row n is client n's averaged gradient.
        """
        stocGradList = [self.computeAverageOfGradients(givenModelVec=givenPersonalizedModelVecs[:,clienti], givenSampleIdList=sampleIdListOfClient)
                        for clienti, sampleIdListOfClient in enumerate(self.__dataFederation.generateSampleIdListForClients())]
        return np.array(stocGradList) # N by d

    def __assembleModels(self, givenX, givenZ):
        """Return M x 1^T + N Z (numberOfFeatures by numberOfClients)."""
        ones = np.ones(self.__federatedLearningSettings.getNumOfClients())
        return np.outer(self.__formulationSettings.getMMat() @ givenX, ones) + self.__formulationSettings.getNMat() @ givenZ

    def executeByFedAvG(self, givenNumOfIterations=500, givenEta=1e-2):
        """
        Run givenNumOfIterations rounds starting from x = 0, Z = 0.

        givenEta: step size passed to AlternativeOptimizer (previously hard-coded to 1e-2).
        """
        featureIdListOfSharingComponent = self.__personalizedModelSettings.getFeatureIdListOfSharingComponent()
        featureIdListOfPersonalizedComponent = self.__personalizedModelSettings.getFeatureIdListOfPersonalizedComponent()
        numOfClients = self.__federatedLearningSettings.getNumOfClients()
        x = np.zeros(len(featureIdListOfSharingComponent))
        Z = np.zeros((len(featureIdListOfPersonalizedComponent), numOfClients))
        self.__personalizedModelVecs = self.__assembleModels(x, Z)
        QMat = self.__dataFederation.generateQMat()
        for _ in range(givenNumOfIterations):
            nablaMat = np.transpose(self.computePersonalizedGradients(givenPersonalizedModelVecs=self.__personalizedModelVecs)) # d by N
            opti = AlternativeOptimizer(givenInitX=x, givenInitZ=Z, givenQ=QMat, givenFormulationSettings=self.__formulationSettings,
                                        givenEta=givenEta, givenNabla=nablaMat)
            x, Z = opti.getX(), opti.getZ()
            self.__personalizedModelVecs = self.__assembleModels(x, Z)

    def getPersonalizedModels(self):
        """Return the personalized models (numberOfFeatures by numberOfClients)."""
        return self.__personalizedModelVecs



class Test(object):
    """
    Manual end-to-end runs of the personalized convex clustering pipeline.
    """
    def testOptimizer(self):
        """
        Train once with lambda = 1 on a 100 x 6 toy dataset split over 5 clients and
        return the personalized models (numberOfFeatures by numberOfClients).
        """
        datasetObj = ToyDataOfConvexCluster(givenNumberOfSamples=100, givenNumberOfFeatures=6)
        d = datasetObj.getNumberOfFeatures()
        featureIdListOfSharingComponent = range(0, d-2)
        featureIdListOfPersonalizedComponent = range(d-2, d)
        personalizedModelSettings = PersonalizedModelSettings(givenFeatureIdListOfSharingComponent=featureIdListOfSharingComponent, givenFeatureIdListOfPersonalizedComponent=featureIdListOfPersonalizedComponent)
        formulationSettings = FormulationSettings(givenFeatureIdListOfSharingComponent=featureIdListOfSharingComponent, givenFeatureIdListOfPersonalizedComponent=featureIdListOfPersonalizedComponent, givenNumberOfSamples=datasetObj.getNumberOfSamples(), givenLambda=1)
        federatedLearningSettings = FederatedLearningSettings(givenNumberOfClients=5, givenNumberOfIterations=100)
        # PersonalizedConvexCluster expects a DataFederation via 'givenDataFederation';
        # the former 'givenDatasetObject' keyword raised TypeError.
        dataFederation = DataFederation(givenDatasetObj=datasetObj, givenFederatedLearningSettings=federatedLearningSettings)
        pcc = PersonalizedConvexCluster(givenDataFederation=dataFederation, givenPersonalizedModelSettings=personalizedModelSettings,
                                        givenFormulationSettings=formulationSettings, givenFederatedLearningSettings=federatedLearningSettings)
        return pcc.getPersonalizedModels()

    def testClusterPath(self):
        """
        Sweep lambda, collect the last two model coordinates of every client for each
        value, and plot the resulting cluster path.
        """
        datasetObj = ToyDataOfConvexCluster(givenNumberOfSamples=100, givenNumberOfFeatures=6)
        d = datasetObj.getNumberOfFeatures()
        featureIdListOfSharingComponent = range(0, d-2)
        featureIdListOfPersonalizedComponent = range(d-2, d)
        personalizedModelSettings = PersonalizedModelSettings(givenFeatureIdListOfSharingComponent=featureIdListOfSharingComponent, givenFeatureIdListOfPersonalizedComponent=featureIdListOfPersonalizedComponent)
        federatedLearningSettings = FederatedLearningSettings(givenNumberOfClients=8, givenNumberOfIterations=200)
        dataFederation = DataFederation(givenDatasetObj=datasetObj, givenFederatedLearningSettings=federatedLearningSettings)
        personalizedCoords = []
        for lambdav in [0, 1e-8, 1e-7, 1e-6, 1e-3]:
            formulationSettings = FormulationSettings(givenFeatureIdListOfSharingComponent=featureIdListOfSharingComponent, givenFeatureIdListOfPersonalizedComponent=featureIdListOfPersonalizedComponent,
                                                      givenNumberOfSamples=datasetObj.getNumberOfSamples(), givenLambda=lambdav)
            pcc = PersonalizedConvexCluster(givenDataFederation=dataFederation, givenPersonalizedModelSettings=personalizedModelSettings,
                                            givenFormulationSettings=formulationSettings, givenFederatedLearningSettings=federatedLearningSettings)
            personalizedModels = pcc.getPersonalizedModels()
            # clients as rows, last two features as the 2-D coordinates
            personalizedCoords.append(np.transpose(personalizedModels[-2:,:]))
        # named clusterPath (not 'cp') so it does not shadow the module-level cvxpy alias
        clusterPath = ClusterPath(givenPointCollections=personalizedCoords)
        clusterPath.plotFullPath()

# Script entry point: seed NumPy's global RNG so the random toy data is reproducible,
# then run the lambda-sweep cluster-path experiment (which plots the result).
# NOTE(review): this also runs on import; consider an `if __name__ == "__main__":` guard.
np.random.seed(seed=0)
Test().testClusterPath()
















